/*
 * Copyright 2025 Adobe. All rights reserved.
 * This file is licensed to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License. You may obtain a copy
 * of the License at http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under
 * the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
 * OF ANY KIND, either express or implied. See the License for the specific language
 * governing permissions and limitations under the License.
 */

// This .cpp provides the implementations of the functions declared in these headers:
#include <lagrange/io/save_mesh_obj.h>
#include <lagrange/io/save_scene_obj.h>
#include <lagrange/io/save_simple_scene_obj.h>

#include <lagrange/Attribute.h>
#include <lagrange/Logger.h>
#include <lagrange/SurfaceMeshTypes.h>
#include <lagrange/foreach_attribute.h>
#include <lagrange/image_io/save_image.h>
#include <lagrange/io/api.h>
#include <lagrange/scene/SceneTypes.h>
#include <lagrange/scene/SimpleSceneTypes.h>
#include <lagrange/utils/assert.h>

#include <fstream>
#include <functional>
#include <ostream>
#include <set>

// clang-format off
#include <lagrange/utils/warnoff.h>
#include <spdlog/fmt/ostr.h>
#include <lagrange/utils/warnon.h>
// clang-format on

namespace lagrange {
namespace io {

// =====================================
// Helper functions for OBJ writing
// =====================================

namespace {

/// Writes the comment banner that opens every OBJ file produced by this module.
///
/// The banner reports the vertex and facet totals passed in. `Scalar` is not used; it is kept so
/// call sites can spell out both template parameters uniformly.
template <typename Scalar, typename Index>
void write_obj_header(std::ostream& output_stream, Index num_vertices, Index num_facets)
{
    // Fixed title block.
    output_stream << "####\n"
                     "#\n"
                     "# OBJ File Generated by Lagrange\n"
                     "#\n"
                     "####\n"
                     "#\n";

    // Element counts (formatted through fmt, like the rest of the writer).
    fmt::print(output_stream, "# Vertices: {}\n# Faces: {}\n", num_vertices, num_facets);

    // Closing rule followed by an empty line before the geometry starts.
    output_stream << "#\n"
                     "####\n"
                     "\n";
}

template <typename Scalar, typename Index>
struct AttributeWriteResult
{
    std::string found_uv_name; // Name of the UV attribute that was written
    std::string found_normal_name; // Name of the normal attribute that was written
    span<const Index> uv_indices; // Index mapping for UV coordinates (points to existing data)
    span<const Index> normal_indices; // Index mapping for normals (points to existing data or
                                      // normal_index_buffer)

    // May be empty or may contain data for custom normal index mappings (facet/corner attributes).
    // normal_indices above may point to this buffer, this ensures that the data is still allocated.
    std::vector<Index> normal_index_buffer;

    // Number of attribute values written (for offset calculation)
    Index uv_values_written = 0;
    Index normal_values_written = 0;
};

template <typename Scalar, typename Index>
AttributeWriteResult<Scalar, Index> write_mesh_attributes(
    std::ostream& output_stream,
    const SurfaceMesh<Scalar, Index>& mesh,
    const SaveOptions& options)
{
    AttributeWriteResult<Scalar, Index> result;

    seq_foreach_named_attribute_read(mesh, [&](std::string_view name, auto&& attr) {
        using AttributeType = std::decay_t<decltype(attr)>;

        // TODO: change this for the attribute visitor that takes id and simplify this block.
        if (options.output_attributes == SaveOptions::OutputAttributes::SelectedOnly) {
            AttributeId id = mesh.get_attribute_id(name);
            if (std::find(
                    options.selected_attributes.begin(),
                    options.selected_attributes.end(),
                    id) == options.selected_attributes.end()) {
                return;
            }
        }

        if (attr.get_usage() == AttributeUsage::UV) {
            if (result.found_uv_name.empty()) {
                result.found_uv_name = name;
            } else {
                logger().warn(
                    "Found multiple UV attributes. This is not supported. '{}' is saved, '{}' is "
                    "skipped",
                    result.found_uv_name,
                    name);
                return;
            }

            const Attribute<typename AttributeType::ValueType>* values = nullptr;
            if constexpr (AttributeType::IsIndexed) {
                values = &attr.values();
                result.uv_indices = attr.indices().get_all();
            } else {
                values = &attr;
                result.uv_indices = mesh.get_corner_to_vertex().get_all();
            }
            la_runtime_assert(attr.get_num_channels() == 2);
            result.uv_values_written = static_cast<Index>(values->get_num_elements());
            for (Index vt = 0; vt < values->get_num_elements(); ++vt) {
                auto p = values->get_row(vt);
                fmt::print(output_stream, "vt {} {}\n", p[0], p[1]);
            }
        }

        if (attr.get_usage() == AttributeUsage::Normal) {
            if (result.found_normal_name.empty()) {
                result.found_normal_name = name;
            } else {
                logger().warn(
                    "Found multiple Normal attributes. This is not supported. '{}' is saved, '{}' "
                    "is skipped",
                    result.found_normal_name,
                    name);
                return;
            }

            const Attribute<typename AttributeType::ValueType>* values = nullptr;
            if constexpr (AttributeType::IsIndexed) {
                values = &attr.values();
                result.normal_indices = attr.indices().get_all();
            } else if (attr.get_element_type() == AttributeElement::Vertex) {
                values = &attr;
                result.normal_indices = mesh.get_corner_to_vertex().get_all();
            } else if (attr.get_element_type() == AttributeElement::Facet) {
                values = &attr;
                result.normal_index_buffer.resize(mesh.get_num_corners());
                for (Index ci = 0; ci < mesh.get_num_corners(); ci++) {
                    result.normal_index_buffer[ci] = mesh.get_corner_facet(ci);
                }
                result.normal_indices = result.normal_index_buffer;
            } else if (attr.get_element_type() == AttributeElement::Corner) {
                values = &attr;
                result.normal_index_buffer.resize(mesh.get_num_corners());
                for (Index ci = 0; ci < mesh.get_num_corners(); ci++) {
                    result.normal_index_buffer[ci] = ci;
                }
                result.normal_indices = result.normal_index_buffer;
            } else {
                logger().warn(
                    "Skipping normal attribute '{}' due to unsupported element type",
                    result.found_normal_name);
                result.found_normal_name.clear();
                return;
            }
            la_runtime_assert(attr.get_num_channels() == 3);
            result.normal_values_written = static_cast<Index>(values->get_num_elements());
            for (Index vn = 0; vn < values->get_num_elements(); ++vn) {
                auto p = values->get_row(vn);
                fmt::print(output_stream, "vn {} {} {}\n", p[0], p[1], p[2]);
            }
        }
    });

    return result;
}

template <typename Scalar, typename Index, int Dim>
void write_mesh_vertices(
    std::ostream& output_stream,
    const SurfaceMesh<Scalar, Index>& mesh,
    const Eigen::Transform<Scalar, Dim, Eigen::Affine>& transform =
        Eigen::Transform<Scalar, Dim, Eigen::Affine>::Identity())
{
    static_assert(Dim == 2 || Dim == 3, "Unsupported dimension for mesh vertices.");

    const Index mesh_dim = mesh.get_dimension();
    la_runtime_assert(mesh_dim == Dim, "Mesh dimension does not match template dimension");

    const Index num_vertices = mesh.get_num_vertices();

    for (Index v = 0; v < num_vertices; ++v) {
        auto pos_span = mesh.get_position(v);

        if constexpr (Dim == 2) {
            Eigen::Matrix<Scalar, Dim, 1> p{pos_span[0], pos_span[1]};
            p = transform * p;
            fmt::print(output_stream, "v {} {}\n", p[0], p[1]);
        } else if constexpr (Dim == 3) {
            Eigen::Matrix<Scalar, Dim, 1> p{pos_span[0], pos_span[1], pos_span[2]};
            p = transform * p;
            fmt::print(output_stream, "v {} {} {}\n", p[0], p[1], p[2]);
        }
    }
}

template <typename Scalar, typename Index>
void write_mesh_facets(
    std::ostream& output_stream,
    const SurfaceMesh<Scalar, Index>& mesh,
    const AttributeWriteResult<Scalar, Index>& attr_result,
    Index vertex_offset = 0,
    Index uv_offset = 0,
    Index normal_offset = 0)
{
    const Index num_facets = mesh.get_num_facets();

    for (Index f = 0; f < num_facets; ++f) {
        const Index first_corner = mesh.get_facet_corner_begin(f);
        const auto vtx_indices = mesh.get_facet_vertices(f);
        la_runtime_assert(
            vtx_indices.size() >= 3,
            fmt::format("Mesh facet {} should have >= 3 vertices", f));
        output_stream << "f";
        for (Index lv = 0; lv < vtx_indices.size(); ++lv) {
            // vertex_index/texture_index/normal_index (OBJ indices are 1-based)
            Index v = vtx_indices[lv] + 1 + vertex_offset;
            Index vt =
                (!attr_result.uv_indices.empty() ? attr_result.uv_indices[first_corner + lv] : 0) +
                1 + uv_offset;
            Index vn =
                (!attr_result.normal_indices.empty() ? attr_result.normal_indices[first_corner + lv]
                                                     : 0) +
                1 + normal_offset;

            if (attr_result.uv_indices.empty() && attr_result.normal_indices.empty()) {
                fmt::print(output_stream, " {}", v);
            } else if (!attr_result.uv_indices.empty() && attr_result.normal_indices.empty()) {
                fmt::print(output_stream, " {}/{}", v, vt);
            } else if (!attr_result.uv_indices.empty() && !attr_result.normal_indices.empty()) {
                fmt::print(output_stream, " {}/{}/{}", v, vt, vn);
            } else if (attr_result.uv_indices.empty() && !attr_result.normal_indices.empty()) {
                fmt::print(output_stream, " {}//{}", v, vn);
            }
        }
        output_stream << "\n";
    }
}

/// Writes one texture map directive (e.g. `map_Kd <file>`) to an MTL stream and makes sure the
/// referenced image file is available in `base_dir`.
///
/// - If the image carries pixel data, it is saved to `base_dir / image.uri`, or to
///   `base_dir / "texture_<index>.png"` when the URI is empty.
/// - Otherwise the file named by `image.uri` (resolved against `base_dir` when relative) is copied
///   into `base_dir`, keeping its file name. No copy is made if it is already there.
///
/// Nothing is written if `texture_info` does not reference a texture.
///
/// @throws std::runtime_error if the image has neither pixel data nor a URI, or if the URI does
///         not point to an existing file.
template <typename Scalar, typename Index>
void write_texture_to_mtl(
    std::ostream& mtl_stream,
    const scene::Scene<Scalar, Index>& scene,
    const scene::TextureInfo& texture_info,
    const fs::path& base_dir,
    const std::string& map_directive)
{
    if (texture_info.index == scene::invalid_element) return;
    la_debug_assert(texture_info.index < scene.textures.size());

    const auto& texture = scene.textures[texture_info.index];
    la_debug_assert(texture.image != scene::invalid_element);
    la_debug_assert(texture.image < scene.images.size());

    const auto& image = scene.images[texture.image];
    fs::path image_filename;

    if (!image.image.data.empty()) {
        // Image data is available, save it to a file
        if (image.uri.empty()) {
            image_filename = fmt::format("texture_{}.png", texture_info.index);
        } else {
            image_filename = image.uri;
        }

        // NOTE(review): pixel data is always saved as 8-bit; confirm the buffer holds uint8 data.
        lagrange::image_io::save_image(
            base_dir / image_filename,
            image.image.data.data(),
            image.image.width,
            image.image.height,
            image::ImagePrecision::uint8,
            static_cast<lagrange::image::ImageChannel>(image.image.num_channels));

        fmt::print(mtl_stream, "{} {}\n", map_directive, image_filename.string());
    } else if (!image.uri.empty()) {
        // No image data but a URI exists: copy the referenced file next to the MTL file.
        fs::path source_path = image.uri;
        if (source_path.is_relative()) {
            // A relative URI is assumed to be relative to the base directory.
            source_path = base_dir / source_path;
        }

        // Keep the original file name from the URI.
        image_filename = fs::path(image.uri).filename();
        const fs::path dest_path = base_dir / image_filename;

        if (!fs::exists(source_path)) {
            throw std::runtime_error(
                fmt::format("Texture file not found: {}", source_path.string()));
        }

        // copy_file() reports an error when source and destination are the same file. That
        // happens when a relative URI already names a file directly inside base_dir, so only
        // copy when the destination is a different file.
        const bool already_in_place =
            fs::exists(dest_path) && fs::equivalent(source_path, dest_path);
        if (!already_in_place) {
            fs::copy_file(source_path, dest_path, fs::copy_options::overwrite_existing);
        }

        fmt::print(mtl_stream, "{} {}\n", map_directive, image_filename.string());
    } else {
        // Neither image data nor URI exists
        throw std::runtime_error(
            fmt::format("Texture {} has no image data and no URI", texture_info.index));
    }
}

/// Writes an MTL material library containing one `newmtl` entry per material of `scene`.
///
/// PBR materials are approximated with fixed Phong-style parameters:
/// - diffuse comes from the base color;
/// - ambient is 10% of the base color;
/// - opacity is the base color's alpha.
///
/// Base color and normal textures are emitted as `map_Kd` / `map_Bump`. Their image files are
/// placed in the directory containing `mtl_filename`.
///
/// @throws std::runtime_error if the MTL file cannot be opened for writing.
template <typename Scalar, typename Index>
void write_mtl_file(const fs::path& mtl_filename, const scene::Scene<Scalar, Index>& scene)
{
    fs::ofstream out(mtl_filename);
    if (!out) {
        throw std::runtime_error(
            fmt::format("Failed to open MTL file for writing: {}", mtl_filename.string()));
    }

    // Texture images are written/copied next to the MTL file.
    const fs::path texture_dir = mtl_filename.parent_path();

    fmt::print(out, "# MTL File Generated by Lagrange\n");
    fmt::print(out, "# Materials: {}\n\n", scene.materials.size());

    for (size_t i = 0; i < scene.materials.size(); ++i) {
        const auto& mat = scene.materials[i];
        const auto& rgba = mat.base_color_value;

        // Anonymous materials get a synthesized name; save_scene_obj_impl derives the `usemtl`
        // names in the same way.
        const std::string mat_name = mat.name.empty() ? fmt::format("material_{}", i) : mat.name;
        fmt::print(out, "newmtl {}\n", mat_name);

        // PBR -> Phong is only approximated: colors come from the base color, the specular
        // response is a fixed low dielectric value.
        fmt::print(out, "Kd {} {} {}\n", rgba[0], rgba[1], rgba[2]);
        fmt::print(out, "Ka {} {} {}\n", rgba[0] * 0.1f, rgba[1] * 0.1f, rgba[2] * 0.1f);
        fmt::print(out, "Ks 0.04 0.04 0.04\n");
        fmt::print(out, "Ns 32\n");
        fmt::print(out, "d {}\n", rgba[3]);
        fmt::print(out, "illum 2\n");

        if (mat.base_color_texture.index != scene::invalid_element) {
            write_texture_to_mtl(out, scene, mat.base_color_texture, texture_dir, "map_Kd");
        }
        if (mat.normal_texture.index != scene::invalid_element) {
            write_texture_to_mtl(out, scene, mat.normal_texture, texture_dir, "map_Bump");
        }

        fmt::print(out, "\n");
    }
}

/// Writes every mesh instance reachable from the scene's root nodes as a separate OBJ object.
///
/// Node transforms are accumulated from the roots downward and baked into the vertex positions.
/// 2D meshes are embedded in the z = 0 plane so the 3D node transforms apply to them. When
/// `options.export_materials` is set and the scene has materials:
/// - an MTL file with the same stem and a `.mtl` extension is written next to `obj_filename`;
/// - it is referenced through `mtllib`;
/// - each object uses the first material of its mesh instance.
///
/// @param output_stream  Destination of the OBJ text.
/// @param obj_filename   Path of the OBJ file being written, or empty when writing to a plain
///                       stream (in which case materials cannot be exported).
/// @param scene          Scene to write.
/// @param options        Save options (attribute selection, material export).
///
/// @throws std::runtime_error if materials must be exported but `obj_filename` is empty.
template <typename Scalar, typename Index>
void save_scene_obj_impl(
    std::ostream& output_stream,
    const fs::path& obj_filename,
    const scene::Scene<Scalar, Index>& scene,
    const SaveOptions& options)
{
    using Vector3 = Eigen::Matrix<Scalar, 3, 1>;
    using Transform3 = Eigen::Transform<Scalar, 3, Eigen::Affine>;

    // Check if we need to export materials
    const bool has_materials = !scene.materials.empty();
    const bool should_export_materials = options.export_materials && has_materials;
    const bool is_stream_output = obj_filename.empty();

    // For stream output, we can't create external MTL files, so throw an exception
    if (should_export_materials && is_stream_output) {
        throw std::runtime_error(
            "Cannot export materials when saving to stream. "
            "Use file-based save_scene_obj() instead or set export_materials=false.");
    }

    // True if the instance references an existing mesh. Checked at runtime (not only via debug
    // asserts) so that invalid references are never used to index `scene.meshes`.
    auto references_valid_mesh = [&](const auto& mesh_instance) {
        return mesh_instance.mesh != scene::invalid_element &&
               mesh_instance.mesh < scene.meshes.size();
    };
    auto has_supported_dimension = [](const auto& mesh) {
        const auto dim = mesh.get_dimension();
        return dim == 2 || dim == 3;
    };

    // Count vertices and facets over every valid mesh instance in the scene (header only).
    Index total_vertices = 0;
    Index total_facets = 0;
    for (const auto& node : scene.nodes) {
        for (const auto& mesh_instance : node.meshes) {
            if (!references_valid_mesh(mesh_instance)) continue;
            const auto& mesh = scene.meshes[mesh_instance.mesh];
            if (!has_supported_dimension(mesh)) continue;
            total_vertices += mesh.get_num_vertices();
            total_facets += mesh.get_num_facets();
        }
    }

    write_obj_header<Scalar, Index>(output_stream, total_vertices, total_facets);

    // Write the MTL file and reference it if materials are being exported
    if (should_export_materials) {
        fs::path mtl_filename = obj_filename;
        mtl_filename.replace_extension(".mtl");
        fmt::print(output_stream, "mtllib {}\n\n", mtl_filename.filename().string());
        write_mtl_file(mtl_filename, scene);
    }

    // Global offsets: OBJ indices are shared by all objects in the file.
    Index vertex_offset = 0;
    Index uv_offset = 0;
    Index normal_offset = 0;

    // Writes the meshes of a node, then recurses into its children with the accumulated transform.
    std::function<void(scene::ElementId, const Transform3&)> process_node =
        [&](scene::ElementId node_id, const Transform3& parent_transform) {
            const auto& node = scene.nodes[node_id];
            const Transform3 node_transform =
                parent_transform * node.transform.template cast<Scalar>();

            for (const auto& mesh_instance : node.meshes) {
                if (!references_valid_mesh(mesh_instance)) continue;
                const auto& mesh = scene.meshes[mesh_instance.mesh];

                const Index dim = mesh.get_dimension();
                if (dim != 2 && dim != 3) {
                    logger().warn("Skipping mesh with unsupported dimension: {}", dim);
                    continue;
                }

                std::string obj_name =
                    node.name.empty()
                        ? fmt::format("node_{}_mesh_{}", node_id, mesh_instance.mesh)
                        : fmt::format("{}_{}", node.name, mesh_instance.mesh);
                fmt::print(output_stream, "o {}\n", obj_name);

                // Set material if available (only the first material of the instance is used)
                if (should_export_materials && !mesh_instance.materials.empty()) {
                    size_t mat_idx = mesh_instance.materials[0];
                    if (mat_idx < scene.materials.size()) {
                        const auto& material = scene.materials[mat_idx];
                        std::string mat_name = material.name.empty()
                                                   ? fmt::format("material_{}", mat_idx)
                                                   : material.name;
                        fmt::print(output_stream, "usemtl {}\n", mat_name);
                    }
                }

                // Write vertices
                const Index num_vertices = mesh.get_num_vertices();
                if (dim == 3) {
                    write_mesh_vertices<Scalar, Index, 3>(output_stream, mesh, node_transform);
                } else {
                    // Node transforms are 3D and write_mesh_vertices<..., 3> asserts a 3D mesh,
                    // so 2D meshes are embedded in the z = 0 plane and transformed here.
                    for (Index v = 0; v < num_vertices; ++v) {
                        const auto pos = mesh.get_position(v);
                        const Vector3 p = node_transform * Vector3(pos[0], pos[1], Scalar(0));
                        fmt::print(output_stream, "v {} {} {}\n", p[0], p[1], p[2]);
                    }
                }

                // Write attributes and facets for this mesh instance
                auto attr_result = write_mesh_attributes(output_stream, mesh, options);
                write_mesh_facets(
                    output_stream,
                    mesh,
                    attr_result,
                    vertex_offset,
                    uv_offset,
                    normal_offset);

                // Update offsets for next mesh instance
                vertex_offset += num_vertices;
                uv_offset += attr_result.uv_values_written;
                normal_offset += attr_result.normal_values_written;
            }

            // Process child nodes recursively
            for (scene::ElementId child_id : node.children) {
                if (child_id < scene.nodes.size()) {
                    process_node(child_id, node_transform);
                }
            }
        };

    // Process all root nodes
    for (scene::ElementId root_id : scene.root_nodes) {
        if (root_id < scene.nodes.size()) {
            process_node(root_id, Transform3::Identity());
        }
    }
}

} // anonymous namespace

// =====================================
// save_mesh_obj.h
// =====================================
/// Writes `mesh` to `output_stream` as a single OBJ object named `mesh`.
///
/// Output order: header, positions, then the first UV / normal attributes (as selected by
/// `options`), then the facets.
///
/// Asserts that the stream is valid and that the mesh is 2D or 3D.
template <typename Scalar, typename Index>
void save_mesh_obj(
    std::ostream& output_stream,
    const SurfaceMesh<Scalar, Index>& mesh,
    const SaveOptions& options)
{
    la_runtime_assert(output_stream, "Invalid output stream");

    const Index dim = mesh.get_dimension();
    la_runtime_assert(dim == 2 || dim == 3, "Mesh dimension should be 2 or 3");

    write_obj_header<Scalar, Index>(output_stream, mesh.get_num_vertices(), mesh.get_num_facets());
    fmt::print(output_stream, "o mesh\n");

    switch (dim) {
    case 2: write_mesh_vertices<Scalar, Index, 2>(output_stream, mesh); break;
    case 3: write_mesh_vertices<Scalar, Index, 3>(output_stream, mesh); break;
    default: throw std::runtime_error(fmt::format("Unsupported mesh dimension: {}", dim));
    }

    // Texcoords/normals must precede the facets that reference them.
    const auto attributes = write_mesh_attributes(output_stream, mesh, options);
    write_mesh_facets(output_stream, mesh, attributes);

    // TODO: Write edges
}

/// Writes `mesh` to the OBJ file `filename`, creating its parent directory if it does not exist.
///
/// @throws std::runtime_error if the file cannot be opened for writing.
template <typename Scalar, typename Index>
void save_mesh_obj(
    const fs::path& filename,
    const SurfaceMesh<Scalar, Index>& mesh,
    const SaveOptions& options)
{
    if (const fs::path dir = filename.parent_path(); !dir.empty() && !fs::exists(dir)) {
        fs::create_directories(dir);
    }

    fs::ofstream obj_stream(filename);
    if (!obj_stream) {
        throw std::runtime_error(
            fmt::format("Failed to open OBJ file for writing: {}", filename.string()));
    }
    save_mesh_obj(obj_stream, mesh, options);
}

// Explicit instantiations of both save_mesh_obj overloads (stream and file). LA_SURFACE_MESH_X is
// expected to invoke LA_X_save_mesh_obj once per supported (Scalar, Index) pair.
#define LA_X_save_mesh_obj(_, Scalar, Index)    \
    template LA_IO_API void save_mesh_obj(      \
        std::ostream& output_stream,            \
        const SurfaceMesh<Scalar, Index>& mesh, \
        const SaveOptions& options);            \
    template LA_IO_API void save_mesh_obj(      \
        const fs::path& filename,               \
        const SurfaceMesh<Scalar, Index>& mesh, \
        const SaveOptions& options);
LA_SURFACE_MESH_X(save_mesh_obj, 0)
#undef LA_X_save_mesh_obj


// =====================================
// save_simple_scene_obj.h
// =====================================
/// Writes every instance of every mesh of a SimpleScene as a separate OBJ object named
/// `mesh_<m>_instance_<i>`, with the instance transform baked into the vertex positions.
///
/// UV/normal values are re-emitted for every instance, and v/vt/vn references are offset by the
/// counts already written, so each object indexes only its own data.
///
/// Asserts that the stream is valid and that every mesh is 2D or 3D.
template <typename Scalar, typename Index, size_t Dimension>
void save_simple_scene_obj(
    std::ostream& output_stream,
    const scene::SimpleScene<Scalar, Index, Dimension>& lscene,
    const SaveOptions& options)
{
    la_runtime_assert(output_stream, "Invalid output stream");

    // Header totals: every instance contributes a full copy of its mesh.
    Index total_vertices = 0;
    Index total_facets = 0;
    for (Index m = 0; m < lscene.get_num_meshes(); ++m) {
        const auto& mesh = lscene.get_mesh(m);
        const Index copies = lscene.get_num_instances(m);
        total_vertices += mesh.get_num_vertices() * copies;
        total_facets += mesh.get_num_facets() * copies;
    }

    write_obj_header<Scalar, Index>(output_stream, total_vertices, total_facets);
    fmt::print(
        output_stream,
        "# Simple scene with {} meshes and {} total instances\n",
        lscene.get_num_meshes(),
        lscene.compute_num_instances());

    // Number of v / vt / vn lines emitted so far; OBJ indices are global to the file.
    Index vertices_so_far = 0;
    Index uvs_so_far = 0;
    Index normals_so_far = 0;

    for (Index m = 0; m < lscene.get_num_meshes(); ++m) {
        const auto& mesh = lscene.get_mesh(m);

        const Index dim = mesh.get_dimension();
        la_runtime_assert(dim == 2 || dim == 3, "Mesh dimension should be 2 or 3");

        Index next_instance = 0;
        lscene.foreach_instances_for_mesh(m, [&](const auto& instance) {
            const Index instance_idx = next_instance++;
            fmt::print(output_stream, "o mesh_{}_instance_{}\n", m, instance_idx);

            if constexpr (Dimension == 3) {
                write_mesh_vertices<Scalar, Index, 3>(output_stream, mesh, instance.transform);
            } else {
                write_mesh_vertices<Scalar, Index, 2>(output_stream, mesh, instance.transform);
            }

            const auto written = write_mesh_attributes(output_stream, mesh, options);
            write_mesh_facets(
                output_stream,
                mesh,
                written,
                vertices_so_far,
                uvs_so_far,
                normals_so_far);

            vertices_so_far += mesh.get_num_vertices();
            uvs_so_far += written.uv_values_written;
            normals_so_far += written.normal_values_written;
        });
    }
}

/// Writes `lscene` to the OBJ file `filename`, creating its parent directory if it does not exist.
///
/// @throws std::runtime_error if the file cannot be opened for writing.
template <typename Scalar, typename Index, size_t Dimension>
void save_simple_scene_obj(
    const fs::path& filename,
    const scene::SimpleScene<Scalar, Index, Dimension>& lscene,
    const SaveOptions& options)
{
    if (const fs::path dir = filename.parent_path(); !dir.empty() && !fs::exists(dir)) {
        fs::create_directories(dir);
    }

    fs::ofstream obj_stream(filename);
    if (!obj_stream) {
        throw std::runtime_error(
            fmt::format("Failed to open OBJ file for writing: {}", filename.string()));
    }
    save_simple_scene_obj(obj_stream, lscene, options);
}

// Explicit instantiations of both save_simple_scene_obj overloads (file and stream).
// LA_SIMPLE_SCENE_X is expected to invoke LA_X_save_simple_scene_obj once per supported
// (Scalar, Index, Dimension) triple.
#define LA_X_save_simple_scene_obj(_, S, I, D)     \
    template LA_IO_API void save_simple_scene_obj( \
        const fs::path& filename,                  \
        const scene::SimpleScene<S, I, D>& scene,  \
        const SaveOptions& options);               \
    template LA_IO_API void save_simple_scene_obj( \
        std::ostream&,                             \
        const scene::SimpleScene<S, I, D>& scene,  \
        const SaveOptions& options);
LA_SIMPLE_SCENE_X(save_simple_scene_obj, 0);
#undef LA_X_save_simple_scene_obj

// =====================================
// save_scene_obj.h
// =====================================
/// Writes `scene` as OBJ text to `output_stream`.
///
/// No file path is available here, so no MTL file can be produced. Requesting material export for
/// a scene that has materials throws.
///
/// @throws std::runtime_error if materials would have to be exported.
template <typename Scalar, typename Index>
void save_scene_obj(
    std::ostream& output_stream,
    const scene::Scene<Scalar, Index>& scene,
    const SaveOptions& options)
{
    // Reject a bad stream up front, consistent with save_mesh_obj / save_simple_scene_obj, instead
    // of silently producing no output.
    la_runtime_assert(output_stream, "Invalid output stream");

    // An empty path marks stream output for the implementation (no MTL file can be written).
    save_scene_obj_impl(output_stream, fs::path{}, scene, options);
}

/// Writes `scene` to the OBJ file `filename`, creating its parent directory if it does not exist.
///
/// The real filename is forwarded so that an MTL file (and its textures) can be written next to
/// the OBJ file when material export is enabled.
///
/// @throws std::runtime_error if the file cannot be opened for writing.
template <typename Scalar, typename Index>
void save_scene_obj(
    const fs::path& filename,
    const scene::Scene<Scalar, Index>& scene,
    const SaveOptions& options)
{
    if (const fs::path dir = filename.parent_path(); !dir.empty() && !fs::exists(dir)) {
        fs::create_directories(dir);
    }

    fs::ofstream obj_stream(filename);
    if (!obj_stream) {
        throw std::runtime_error(
            fmt::format("Failed to open OBJ file for writing: {}", filename.string()));
    }
    save_scene_obj_impl(obj_stream, filename, scene, options);
}

// Explicit instantiations of both save_scene_obj overloads (stream and file). LA_SCENE_X is
// expected to invoke LA_X_save_scene_obj once per supported (Scalar, Index) pair.
#define LA_X_save_scene_obj(_, S, I)        \
    template LA_IO_API void save_scene_obj( \
        std::ostream& output_stream,        \
        const scene::Scene<S, I>& scene,    \
        const SaveOptions& options);        \
    template LA_IO_API void save_scene_obj( \
        const fs::path& filename,           \
        const scene::Scene<S, I>& scene,    \
        const SaveOptions& options);
LA_SCENE_X(save_scene_obj, 0);
#undef LA_X_save_scene_obj

} // namespace io
} // namespace lagrange
